{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# StableDiffusion模型LoRA微调\n",
    "\n",
    "[StableDiffusion](https://huggingface.co/runwayml/stable-diffusion-v1-5)是由StabilityAI、CompVis与Runway合作开发并开源的文本生成图像的模型。他可以直接用于文本生成图像的任务，也可以作为基础模型进行微调，从而从数据集上学习到新的风格，或是用于完成新的任务。本文将介绍通过在PAI完成LoRA微调StableDiffusion模型。\n",
    "\n",
    "## 背景介绍\n",
    "\n",
    "[LoRA（Low-Rank Adaption of Large Language Model）](https://arxiv.org/abs/2106.09685)是由微软提出的高效微调大语言模型的方法，他通过冻结原始模型参数，在模型上新增低秩矩阵作为可训练参数的方式微调模型。研究者发现，通过在Transformer块的Attention层上添加LoRA低秩矩阵对模型进行微调，能够获得与全参数微调水平相近的模型。相比于全参数的微调，LoRA有以下优点：\n",
    "\n",
    "- 训练的参数量小，计算资源消耗低，训练速度更快。\n",
    "  \n",
    "- 对于计算资源/显存的要求更低，支持用户在消费级/中低端的GPU卡对大模型进行微调。\n",
    "\n",
    "- 冻结了原始模型参数，在训练过程中不容易发生灾难性遗忘。\n",
    "\n",
    "- 产出的模型较小，存储的成本较低，仅需推理时和原始的模型一同使用进行推理。\n",
    "\n",
    "后续有开发者，将其引入到[StableDiffsion模型的微调](https://github.com/cloneofsimo/lora)中，取得了不错的效果。HuggingFace提供的[Diffusers库](https://github.com/huggingface/diffusers)支持用户使用扩散模型进行训练或是推理，他支持用户使用LoRA微调扩散模型，并提供了相应的训练代码，支持[文生图](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py)，以及[DreamBooth](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora.py)的LoRA训练。\n",
    "\n",
    "当前示例，我们将基于[Diffusers库提供的训练代码和文档](https://huggingface.co/docs/diffusers/training/overview)，在PAI完成StableDiffusion v1.5模型的LoRA微调训练。\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "## 准备工作\n",
    "\n",
    "### 安装PAI Python SDK\n",
    "\n",
    "安装PAI Python SDK，用于提交训练任务到PAI。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 安装PAI Python SDK\n",
    "!python -m pip install --upgrade alipai"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "\n",
    "SDK需要配置访问阿里云服务需要的AccessKey，以及当前使用的工作空间和OSS Bucket。在PAI SDK安装之后，通过在**命令行终端** 中执行以下命令，按照引导配置密钥、工作空间等信息。\n",
    "\n",
    "\n",
    "```shell\n",
    "\n",
    "# 以下命令，请在 命令行终端 中执行.\n",
    "\n",
    "python -m pai.toolkit.config\n",
    "\n",
    "```\n",
    "\n",
    "我们可以执行以下代码，验证配置是否成功。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pai\n",
    "from pai.session import get_default_session\n",
    "\n",
    "print(pai.__version__)\n",
    "\n",
    "sess = get_default_session()\n",
    "\n",
    "# 配置成功之后，我们可以拿到工作空间的信息\n",
    "assert sess.workspace_name is not None\n",
    "assert sess.oss_bucket is not None"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 获取PAI提供的StableDiffusion模型\n",
    "\n",
    "PAI的公共模型仓库提供了StableDiffusion v1.5模型，用户可以通过以下代码获取模型的信息，用于后续的微调训练。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pai.session import get_default_session\n",
    "from pai.libs.alibabacloud_aiworkspace20210204.models import ListModelsRequest\n",
    "\n",
    "sess = get_default_session()\n",
    "\n",
    "# 获取PAI提供的StableDiffusion模型信息\n",
    "resp = sess._acs_workspace_client.list_models(\n",
    "    request=ListModelsRequest(\n",
    "        provider=\"pai\",\n",
    "        model_name=\"stable_diffusion_v1.5\",\n",
    "    )\n",
    ")\n",
    "model = resp.body.models[0].latest_version\n",
    "\n",
    "# StableDiffusion 模型的OSS URI（公共读）\n",
    "print(f\"StableDiffusion ModelUri: {model.uri}\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## LoRA TextToImage微调训练\n",
    "\n",
    "通过LoRA训练StableDiffusion模型，可以快速，低成本地获得一个能够生成指定风格的模型。在以下示例中，我们将使用一个Demo的图像文本数据集，对StableDiffusion模型进行LoRA微调。\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 准备训练数据\n",
    "\n",
    "当前示例准备了一个简单的文本图片数据集在`train-data`目录下，包含训练的图片以及相应的标注文件(`metadata.jsonl`)。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "!ls -lh train-data/\n",
    "!cat train-data/metadata.jsonl"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们需要将数据上传到OSS Bucket上，供训练作业使用。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pai.common.oss_utils import upload\n",
    "\n",
    "train_data_uri = upload(\"./train-data/\", \"stable_diffusion_demo/text2image/train-data/\")\n",
    "print(train_data_uri)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Diffuerser提供的训练脚本默认使用[ImageFolder](https://huggingface.co/docs/datasets/en/image_dataset#imagefolder)格式的数据集，用户可以参考以上的格式准备数据，更加详细的介绍可以见HuggingFace datasets的[ImageFolder数据集文档](https://huggingface.co/docs/datasets/en/image_dataset#imagefolder)。"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "### 准备训练作业脚本\n",
    "\n",
    "我们将使用Diffusers库提供的[训练作业脚本(train_text_to_image_lora.py)](https://github.com/huggingface/diffusers/blob/v0.17.1/examples/text_to_image/train_text_to_image_lora.py)完成LoRA训练。执行以下代码，我们将代码下载到本地，用于后续提交训练任务。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!mkdir -p train_lora\n",
    "\n",
    "# code source: https://github.com/huggingface/diffusers/blob/v0.17.1/examples/text_to_image/train_text_to_image_lora.py\n",
    "!wget -P train_lora https://raw.githubusercontent.com/huggingface/diffusers/v0.17.1/examples/text_to_image/train_text_to_image_lora.py"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们提交的训练作业将使用PAI提供的PyTorch 1.12的GPU镜像运行，我们需要准备一个`requirements.txt`文件在训练代码目录下，以安装一些额外的依赖包。\n",
    "\n",
    "训练脚本目录提交到PAI上执行训练时，目录下的`requirements.txt`文件将被安装到作业环境中。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%writefile train_lora/requirements.txt\n",
    "\n",
    "diffusers>=0.17.1\n",
    "\n",
    "# source: https://github.com/huggingface/diffusers/blob/v0.17.1/examples/text_to_image/requirements.txt\n",
    "accelerate>=0.16.0,<=0.18.0\n",
    "torchvision\n",
    "transformers>=4.25.1,<5.0.0\n",
    "datasets\n",
    "ftfy\n",
    "tensorboard\n",
    "Jinja2"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 提交训练作业\n",
    "\n",
    "Diffuers提供的训练脚本，需要使用`Accelerate`工具启动，并通过命令行参数的方式，传递超参，预训练模型路径/ID，以及训练数据集地址。PAI的训练作业，支持通过环境变量的方式获取输入输出的数据/模型路径，以及训练作业超参。以下脚本中，我们通过环境变量的方式，传递超参、输入输出路径给到训练脚本。\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pai.image import retrieve\n",
    "\n",
    "# 使用PAI提供的PyTorch 1.12 GPU镜像\n",
    "image_uri = retrieve(\n",
    "    \"PyTorch\",\n",
    "    \"1.12\",\n",
    "    accelerator_type=\"GPU\",\n",
    ").image_uri\n",
    "\n",
    "print(image_uri)\n",
    "\n",
    "\n",
    "# 训练作业启动命令，通过环境变量的方式获取:\n",
    "# a）输入输出的模型/数据路径\n",
    "# b）训练任务的超参数\n",
    "command = \"\"\"accelerate launch --mixed_precision=\"fp16\"  train_text_to_image_lora.py \\\n",
    "  --pretrained_model_name_or_path=$PAI_INPUT_PRETRAINED_MODEL  \\\n",
    "  --train_data_dir=$PAI_INPUT_TRAIN_DATA \\\n",
    "  --output_dir=$PAI_OUTPUT_MODEL \\\n",
    "  --logging_dir=$PAI_OUTPUT_TENSORBOARD \\\n",
    "  --dataloader_num_workers=8 \\\n",
    "  --resolution=512 --center_crop --random_flip \\\n",
    "  --train_batch_size=$PAI_HPS_TRAIN_BATCH_SIZE \\\n",
    "  --gradient_accumulation_steps=$PAI_HPS_GRADIENT_ACCUMULATION_STEPS \\\n",
    "  --max_train_steps=$PAI_HPS_MAX_TRAIN_STEPS \\\n",
    "  --learning_rate=$PAI_HPS_LEARNING_RATE \\\n",
    "  --checkpointing_steps=$PAI_HPS_CHECKPOINTING_STEPS \\\n",
    "  --max_grad_norm=1 \\\n",
    "  --lr_scheduler=\"cosine\" --lr_warmup_steps=0 \\\n",
    "  --validation_prompt=\"$PAI_HPS_VALIDATION_PROMPT\" \\\n",
    "  --validation_epochs=$PAI_HPS_VALIDATION_EPOCHS \\\n",
    "  --seed=$PAI_HPS_SEED\"\"\"\n",
    "\n",
    "\n",
    "# 训练作业超参\n",
    "hps = {\n",
    "    \"validation_prompt\": \"a photo of cat in a bucket\",  # 验证模型的Prompt\n",
    "    \"validation_epochs\": 1,  # 每隔50个epoch验证一次\n",
    "    \"max_train_steps\": 10,  # 最大训练步数\n",
    "    \"learning_rate\": 1e-4,  # 学习率\n",
    "    \"train_batch_size\": 2,  # 训练batch size\n",
    "    \"gradient_accumulation_steps\": 1,  # 梯度累积步数\n",
    "    \"checkpointing_steps\": 5,  # 每隔100个step保存一次模型\n",
    "    \"seed\": 1337,  # 随机种子\n",
    "}"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "以下代码中，我们使用`Estimator`类，指定训练作业使用的镜像，训练作业超参，输入数据路径等，将LoRA训练作业提交到PAI执行。\n",
    "\n",
    "对于使用SDK提交训练作业的详细介绍，用户可以参考文档: [提交训练作业](https://help.aliyun.com/document_detail/2261505.html)。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pai.estimator import Estimator\n",
    "from pai.image import retrieve\n",
    "\n",
    "# 使用PAI提供的PyTorch 1.12 GPU镜像\n",
    "image_uri = retrieve(\n",
    "    \"PyTorch\",\n",
    "    \"1.12\",\n",
    "    accelerator_type=\"GPU\",\n",
    ").image_uri\n",
    "\n",
    "print(image_uri)\n",
    "\n",
    "\n",
    "# 训练作业启动命令，通过环境变量的方式获取:\n",
    "# a）输入输出的模型/数据路径\n",
    "# b）训练任务的超参数\n",
    "\n",
    "command = \"\"\"accelerate launch --mixed_precision=\"fp16\"  train_text_to_image_lora.py \\\n",
    "  --pretrained_model_name_or_path=$PAI_INPUT_PRETRAINED_MODEL  \\\n",
    "  --train_data_dir=$PAI_INPUT_TRAIN_DATA \\\n",
    "  --output_dir=$PAI_OUTPUT_MODEL \\\n",
    "  --logging_dir=$PAI_OUTPUT_TENSORBOARD \\\n",
    "  --dataloader_num_workers=8 \\\n",
    "  --resolution=512 --center_crop --random_flip \\\n",
    "  --train_batch_size=$PAI_HPS_TRAIN_BATCH_SIZE \\\n",
    "  --gradient_accumulation_steps=$PAI_HPS_GRADIENT_ACCUMULATION_STEPS \\\n",
    "  --max_train_steps=$PAI_HPS_MAX_TRAIN_STEPS \\\n",
    "  --learning_rate=$PAI_HPS_LEARNING_RATE \\\n",
    "  --checkpointing_steps=$PAI_HPS_CHECKPOINTING_STEPS \\\n",
    "  --max_grad_norm=1 \\\n",
    "  --lr_scheduler=\"cosine\" --lr_warmup_steps=0 \\\n",
    "  --validation_prompt=\"$PAI_HPS_VALIDATION_PROMPT\" \\\n",
    "  --validation_epochs=$PAI_HPS_VALIDATION_EPOCHS \\\n",
    "  --seed=$PAI_HPS_SEED\"\"\"\n",
    "\n",
    "\n",
    "# 训练作业超参\n",
    "hps = {\n",
    "    \"validation_prompt\": \"a photo of cat in a bucket\",  # 验证模型的Prompt\n",
    "    \"validation_epochs\": 1,  # 每隔50个epoch验证一次\n",
    "    \"max_train_steps\": 10,  # 最大训练步数\n",
    "    \"learning_rate\": 1e-4,  # 学习率\n",
    "    \"train_batch_size\": 2,  # 训练batch size\n",
    "    \"gradient_accumulation_steps\": 1,  # 梯度累积步数\n",
    "    \"checkpointing_steps\": 5,  # 每隔100个step保存一次模型\n",
    "    \"seed\": 1337,  # 随机种子\n",
    "}\n",
    "\n",
    "\n",
    "est = Estimator(\n",
    "    image_uri=image_uri,  # 训练作业使用的镜像\n",
    "    source_dir=\"train_lora\",  # 训练代码路径，代码会被上传，并准备到训练作业环境中\n",
    "    command=command,  # 训练任务启动命令\n",
    "    instance_type=\"ecs.gn6i-c4g1.xlarge\",  # 4 vCPU, 16 GiB 内存, 1 x NVIDIA T4 GPU\n",
    "    base_job_name=\"sd_lora_t2i_\",  # 训练作业名称前缀\n",
    "    hyperparameters=hps,  # 作业超参，训练命令和脚本可以通过 `PAI_HPS_{HP_NAME_UPPER_CASE}` 环境变量，或是读取`/ml/input/config/hpyerparameters.json`文件获取\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "使用`inputs`参数指定准备到训练作业环境的模型和数据，提交训练作业。 \n",
    "\n",
    "`inputs`参数是一个字典，Key是输入的名称，Value是输入数据的存储路径（例如OSS URI)。相应的数据会被准备到作业执行环境中（通过挂载的方式），训练作业脚本，能够通过环境变量`PAI_INPUT_{KeyUpperCase}`获取到输入数据的路径，通过读取本地文件的方式读取预训练模型和数据。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Input PreTrainedModel: \", model.uri)\n",
    "print(\"Input TrainData: \", train_data_uri)\n",
    "\n",
    "\n",
    "# 提交训练作业\n",
    "est.fit(\n",
    "    inputs={\n",
    "        \"pretrained_model\": model.uri,\n",
    "        \"train_data\": train_data_uri,\n",
    "    },\n",
    "    wait=True,\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在启动命令中，我们使用`--output_dir=$PAI_OUTPUT_MODEL`，让训练脚本将模型写出到指定的输出目录中。对应的模型数据会被保存到用户的OSS Bucket中，我们可以通过`est.model_data()`获得输出的模型的OSS URI。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from pai.common.oss_utils import download\n",
    "\n",
    "print(\"OutputModel Path: \", est.model_data())\n",
    "lora_weight_uri = os.path.join(est.model_data(), \"pytorch_lora_weight.bin\")\n",
    "lora_model_path = download(oss_path=lora_weight_uri, local_path=\"./lora_model/\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "以上训练获得LoRA模型，可以使用diffuser的推理pipeline加载使用：\n",
    "\n",
    "```python\n",
    "# StableDiffusionPipeline加载LoRA模型\n",
    "\n",
    "\n",
    "import torch\n",
    "from diffusers import StableDiffusionPipeline\n",
    "\n",
    "# 加载基础模型\n",
    "model_id_or_path = \"<SdModelId_Or_Path>\"\n",
    "pipe = StableDiffusionPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)\n",
    "\n",
    "# 加载LoRA模型\n",
    "pipe.unet.load_attn_procs(lora_model_path)\n",
    "\n",
    "# 使用推理pipeline\n",
    "image = pipe(\n",
    "    \"A pokemon with blue eyes.\", num_inference_steps=25, guidance_scale=7.5,\n",
    "    cross_attention_kwargs={\"scale\": 0.5},\n",
    ").images[0]\n",
    "\n",
    "\n",
    "```\n",
    "\n",
    "或则用户也可以将其转为safetensor格式，在StableDiffusiuson WebUI中使用。\n",
    "\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from safetensors.torch import save_file\n",
    "\n",
    "# 加载模型\n",
    "lora_model = torch.load(lora_model_bin_path, map_location=\"cpu\")\n",
    "\n",
    "# 转换为safetensor格式\n",
    "save_file(lora_model, \"lora.safetensors\")\n",
    "\n",
    "```"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## LoRA && DreamBooth微调训练\n",
    "\n",
    "### DreamBooth简介\n",
    "\n",
    "DreamBooth是Google的研究人员于2022年提出的技术，支持在少量的图片上进行训练，然后将自定义的主题注入到扩散模型中。\n",
    "\n",
    "![](./resource/dreambooth.jpeg)\n",
    "\n",
    "图片来源: https://dreambooth.github.io/\n",
    "\n",
    "直接使用少量的图片文本数据集对扩散模型进行训练容易导致过拟合，或是语言漂移。DreamBooth使用以下方式避免了模型的退化：\n",
    "\n",
    "- 用户需要为新的主题选择一个罕见的词（标识符），模型将在训练过程中将这个词和图片的主题进行关联。\n",
    "\n",
    "- 为了避免过拟合和语言漂移，微调过程中，使用相同类别的图片参与训练（这些图片可以由扩散模型生成）。\n",
    "\n",
    "对于DreamBooth的详细介绍，用户可以参考[DreamBooth的博客](https://dreambooth.github.io/)，以及[HuggingFace博客](https://huggingface.co/blog/dreambooth)对于DreamBooth的介绍。\n",
    "\n",
    "当通过DreamBooth训练扩散模型时，用户可以选择进行普通的微调（直接微调原始的模型参数），也可以使用LoRA的方式进行微调，在以下示例中，我们将使用LoRA的方式，进行DreamBooth训练。\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 准备训练数据集\n",
    "\n",
    "为了训练DreamBooth，用户需要准备特定风格的图片数据集，当前示例中，我们准备了数据集在`sks-dog`目录下。\n",
    "\n",
    "通过以下代码，我们将将数据集上传到OSS上，供训练作业使用。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pai.common.oss_utils import upload\n",
    "\n",
    "train_data_uri = upload(\"sks-dog\", \"stable_diffusion/dreambooth/train-sks-dog/\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 准备训练代码\n",
    "\n",
    "我们使用HuggingFace Diffusers库提供的训练脚本，通过LoRA && DreamBooth方式训练扩散模型。通过以下代码，我们下载训练脚本到本地。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 创建训练脚本保存路径\n",
    "!mkdir -p train_dreambooth/\n",
    "\n",
    "# 下载HuggingFace diffusers(v1.17.1)库提供的示例代码（因为访问GitHub的网络并不稳定，用户当出现下载失败，可以多尝试几次）\n",
    "!wget https://raw.githubusercontent.com/huggingface/diffusers/v0.17.1/examples/dreambooth/train_dreambooth_lora.py -P train_dreambooth/"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "训练作业将使用PAI提供的PyTorch镜像运行脚本，我们需要通过以下的`requirements.txt`安装训练脚本依赖的库。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%writefile train_dreambooth/requirements.txt\n",
    "# %%writefile 指令会将当前内容写入到 train_dreambooth/requirements.txt 文件中\n",
    "\n",
    "diffusers>=0.17.1\n",
    "\n",
    "# source: https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/requirements.txt\n",
    "accelerate>=0.16.0,<=0.18.0     # diffusers 提供的示例代码(v0.17.1)无法运行在accelerate>=0.18.0上.\n",
    "torchvision\n",
    "transformers>=4.25.1,<5.0.0\n",
    "ftfy\n",
    "tensorboard\n",
    "Jinja2"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 提交训练作业\n",
    "\n",
    "通过以下代码，我们使用PAI Python SDK，提交训练作业到PAI。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pai.estimator import Estimator\n",
    "from pai.image import retrieve\n",
    "\n",
    "image_uri = retrieve(\n",
    "    \"PyTorch\",\n",
    "    \"latest\",\n",
    "    accelerator_type=\"GPU\",\n",
    ").image_uri\n",
    "\n",
    "\n",
    "# 训练作业启动命令，通过环境变量的方式获取:\n",
    "# a）输入输出的模型/数据路径\n",
    "# b）训练任务的超参数\n",
    "command = \"\"\"accelerate launch train_dreambooth_lora.py \\\n",
    "  --pretrained_model_name_or_path=$PAI_INPUT_PRETRAINED_MODEL  \\\n",
    "  --instance_data_dir=$PAI_INPUT_TRAIN_DATA \\\n",
    "  --output_dir=$PAI_OUTPUT_MODEL \\\n",
    "  --logging_dir=$PAI_OUTPUT_TENSORBOARD \\\n",
    "  --instance_prompt=\"$PAI_HPS_INSTANCE_PROMPT\" \\\n",
    "  --resolution=512 \\\n",
    "  --train_batch_size=$PAI_HPS_TRAIN_BATCH_SIZE \\\n",
    "  --gradient_accumulation_steps=$PAI_HPS_GRADIENT_ACCUMULATION_STEPS \\\n",
    "  --checkpointing_steps=$PAI_HPS_CHECKPOINTING_STEPS \\\n",
    "  --learning_rate=$PAI_HPS_LEARNING_RATE \\\n",
    "  --lr_scheduler=\"constant\" \\\n",
    "  --lr_warmup_steps=0 \\\n",
    "  --max_train_steps=$PAI_HPS_MAX_TRAIN_STEPS \\\n",
    "  --validation_prompt=\"$PAI_HPS_VALIDATION_PROMPT\" \\\n",
    "  --validation_epochs=$PAI_HPS_VALIDATION_EPOCHS \\\n",
    "  --seed=\"0\"\n",
    "  \"\"\"\n",
    "\n",
    "# 训练作业超参\n",
    "hps = {\n",
    "    \"instance_prompt\": \"a photo of sks dog\",  # 训练的图片数据文本使用的标注Prompt。这里的sks是我们使用的数据集的特定风格标识符。\n",
    "    \"validation_prompt\": \"a photo of sks dog in a bucket\",  # 验证模型的Prompt\n",
    "    # \"class_prompt\": \"a photo of dog\",  # 用于生成类别图片数据，避免模型过拟合&&语言偏移\n",
    "    \"validation_epochs\": 50,  # 每隔50个epoch验证一次\n",
    "    \"max_train_steps\": 500,  # 最大训练步数\n",
    "    \"learning_rate\": 1e-4,  # 学习率\n",
    "    \"train_batch_size\": 1,  # 训练batch size\n",
    "    \"gradient_accumulation_steps\": 1,  # 梯度累积步数\n",
    "    \"checkpointing_steps\": 100,  # 每隔100个step保存一次模型\n",
    "}\n",
    "\n",
    "\n",
    "est = Estimator(\n",
    "    image_uri=image_uri,  # 训练作业使用的镜像\n",
    "    source_dir=\"train_dreambooth\",  # 训练代码路径，代码会被上传，并准备到训练作业环境中\n",
    "    command=command,  # 训练任务启动命令\n",
    "    instance_type=\"ecs.gn6i-c4g1.xlarge\",  # 4 vCPU, 16 GiB 内存, 1 x NVIDIA T4 GPU\n",
    "    base_job_name=\"sd_lora_dreambooth_\",  # 训练作业名称前缀\n",
    "    hyperparameters=hps,  # 作业超参，训练命令和脚本可以通过 `PAI_HPS_{HP_NAME_UPPER_CASE}` 环境变量，或是读取`/ml/input/config/hpyerparameters.json`文件获取\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Input PreTrainedModel: \", model.uri)\n",
    "print(\"Input TrainData: \", train_data_uri)\n",
    "\n",
    "est.fit(\n",
    "    inputs={\n",
    "        \"pretrained_model\": model.uri,\n",
    "        \"train_data\": train_data_uri,\n",
    "    },\n",
    "    wait=True,\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "训练任务会在输出目录下，生成一个`pytorch_lora_weights.bin`的模型文件，相应的文件会被上传准备到用户的OSS中，用户可以通过以下的代码，将模型文件下载到本地。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import posixpath\n",
    "\n",
    "from pai.common.oss_utils import download\n",
    "\n",
    "# 输出模型路径\n",
    "output_lora_model = posixpath.join(est.model_data(), \"pytorch_lora_weights.bin\")\n",
    "print(\"OutputModel: \", output_lora_model)\n",
    "\n",
    "model_path = download(output_lora_model, \"./lora_dreambooth_model/\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "获得的LoRA模型，用户可以通过Diffuser提供的API，在推理pipeline加载使用，用户可以参考diffuser的文档：[DreamBooth Inference](https://huggingface.co/docs/diffusers/training/lora#dreambooth-inference)。"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 结语\n",
    "\n",
    "通过当前示例，我们介绍了如何基于HuggingFace diffusers库，在PAI上完成StableDiffusion模型的LoRA微调训练。用户可以通过Diffuers库的API，直接在推理Pipeline中加载使用这些LoRA模型，也可以将模型转换Safetensors格式，用于StableDiffusionWebUI中。\n",
    "\n",
    "除了对于LoRA的支持，Diffusers库支持对于直接对扩散模型微调，也支持包括TextInversion, ControlNet, InstructPix2Pix等方式训练扩散模型，并且提供了相应的训练脚本和教程。用户同样可以参考本示例，在PAI运行这些训练任务。\n",
    "\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 参考\n",
    "\n",
    "- HuggingFace LoRa Tutorial: https://huggingface.co/docs/diffusers/training/lora#texttoimage\n",
    "\n",
    "- HuggingFace LoRA Blog: https://huggingface.co/blog/lora\n",
    "\n",
    "- Google DreamBooth Blog：https://dreambooth.github.io/"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
